import torch
from torchstat import stat
from models.MIMOUNet import build_net
# Checkpoint to profile. The non-SNN baseline checkpoint is kept for reference:
# model_path = "/home/jyz/wxb_ws/MIMO-UNet/results/MIMO-UNet/exp_test/weights/epoch_474_iou_0.619.pkl"
model_path = "/home/jyz/wxb_ws/MIMO-UNet/results/MIMO-UNet-SNN/exp_test_ts_2/weights/epoch_473_iou_0.589.pkl"
# map_location="cpu" lets a checkpoint that was saved from a GPU run be
# deserialized on a machine without CUDA; this script never moves the model
# to a GPU, so the weights are only needed on the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state = torch.load(model_path, map_location="cpu")
# The architecture name must match the checkpoint above; switch to
# build_net("MIMO-UNet") when profiling the baseline checkpoint instead.
model = build_net("MIMO-UNet-SNN")
# The checkpoint is a dict; the network weights are stored under the 'model' key.
model.load_state_dict(state['model'])
# Profile the loaded model, passing the size of a single input image.
# NOTE(review): the upstream torchstat README passes input_size as a
# 3-element (C, H, W) tuple; confirm that the installed torchstat accepts
# this 4-element form with a leading batch dimension.
stat(model, (1, 3, 360, 480))
